{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "from operator import itemgetter\n",
    "from itertools import count\n",
    "from collections import Counter, defaultdict\n",
    "import random\n",
    "import numpy as np\n",
    "import dynet as dy\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# config\n",
    "WORD_DIM = 64\n",
    "LSTM_DIM = 64\n",
    "ACTION_DIM = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# represents a bidirectional mapping from strings to ints\n",
    "class Vocab(object):\n",
    "    def __init__(self, w2i):\n",
    "        self.w2i = dict(w2i)\n",
    "        self.i2w = {i:w for w,i in w2i.iteritems()}\n",
    "\n",
    "    @classmethod\n",
    "    def from_list(cls, words):\n",
    "        w2i = {}\n",
    "        idx = 0\n",
    "        for word in words:\n",
    "            w2i[word] = idx\n",
    "            idx += 1\n",
    "        return Vocab(w2i)\n",
    "\n",
    "    @classmethod\n",
    "    def from_file(cls, vocab_fname):\n",
    "        words = []\n",
    "        with file(vocab_fname) as fh:\n",
    "            for line in fh:\n",
    "                line.strip()\n",
    "                word, count = line.split()\n",
    "                words.append(word)\n",
    "        return Vocab.from_list(words)\n",
    "\n",
    "    def size(self): return len(self.w2i.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# format:\n",
    "# John left . ||| SHIFT SHIFT REDUCE_L SHIFT REDUCE_R\n",
    "def read_oracle(fname, vw, va):\n",
    "    with file(fname) as fh:\n",
    "        for line in fh:\n",
    "            line = line.strip()\n",
    "            ssent, sacts = re.split(r' \\|\\|\\| ', line)\n",
    "            sent = [vw.w2i[x] for x in ssent.split()]\n",
    "            acts = [va.w2i[x] for x in sacts.split()]\n",
    "            sent.reverse()\n",
    "            acts.reverse()\n",
    "            yield (sent, acts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class StackRNN(object):\n",
    "    def __init__(self, rnn, p_empty_embedding = None):\n",
    "        self.s = [(rnn.initial_state(), None)]\n",
    "        self.empty = None\n",
    "        if p_empty_embedding:\n",
    "            self.empty = dy.parameter(p_empty_embedding)\n",
    "    def push(self, expr, extra=None):\n",
    "        self.s.append((self.s[-1][0].add_input(expr), extra))\n",
    "    def pop(self):\n",
    "        return self.s.pop()[1] # return \"extra\" (i.e., whatever the caller wants or None)\n",
    "    def embedding(self):\n",
    "        # work around since inital_state.output() is None\n",
    "        return self.s[-1][0].output() if len(self.s) > 1 else self.empty\n",
    "    def __len__(self):\n",
    "        return len(self.s) - 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# actions the parser can take\n",
    "acts = ['SHIFT', 'REDUCE_L', 'REDUCE_R']\n",
    "vocab_acts = Vocab.from_list(acts)\n",
    "SHIFT = vocab_acts.w2i['SHIFT']\n",
    "REDUCE_L = vocab_acts.w2i['REDUCE_L']\n",
    "REDUCE_R = vocab_acts.w2i['REDUCE_R']\n",
    "NUM_ACTIONS = vocab_acts.size()\n",
    "\n",
    "class TransitionParser(object):\n",
    "    def __init__(self, model, vocab):\n",
    "        self.vocab = vocab\n",
    "        # syntactic composition\n",
    "        self.pW_comp = model.add_parameters((LSTM_DIM, LSTM_DIM * 2))\n",
    "        self.pb_comp = model.add_parameters(LSTM_DIM)\n",
    "        # parser state to hidden\n",
    "        self.pW_s2h = model.add_parameters((LSTM_DIM, LSTM_DIM * 2))\n",
    "        self.pb_s2h = model.add_parameters(LSTM_DIM)\n",
    "        # hidden to action\n",
    "        self.pW_act = model.add_parameters((NUM_ACTIONS, LSTM_DIM))\n",
    "        self.pb_act = model.add_parameters(NUM_ACTIONS)\n",
    "\n",
    "        # layers, in-dim, out-dim, model\n",
    "        self.buffRNN = dy.LSTMBuilder(1, WORD_DIM, LSTM_DIM, model)\n",
    "        self.stackRNN = dy.LSTMBuilder(1, WORD_DIM, LSTM_DIM, model)\n",
    "        self.pempty_buffer_emb = model.add_parameters(LSTM_DIM)\n",
    "        self.WORDS_LOOKUP = model.add_lookup_parameters((vocab.size(), WORD_DIM))\n",
    "\n",
    "    # Returns an expression of the loss for the sequence of actions.\n",
    "    # (that is, the oracle_actions if present or the predicted sequence otherwise)\n",
    "    def parse(self, tokens, oracle_actions=None):\n",
    "        def _valid_actions(stack, buffer):\n",
    "            valid_actions = []\n",
    "            if len(buffer) > 0:\n",
    "                valid_actions += [SHIFT]\n",
    "            if len(stack) >= 2:\n",
    "                valid_actions += [REDUCE_L, REDUCE_R]\n",
    "            return valid_actions\n",
    "\n",
    "        dy.renew_cg() # each sentence gets its own graph\n",
    "        if oracle_actions: oracle_actions = list(oracle_actions)\n",
    "        buffer = StackRNN(self.buffRNN, self.pempty_buffer_emb)\n",
    "        stack = StackRNN(self.stackRNN)\n",
    "    \n",
    "        # Put the parameters in the cg\n",
    "        W_comp = dy.parameter(self.pW_comp) # syntactic composition\n",
    "        b_comp = dy.parameter(self.pb_comp)\n",
    "        W_s2h = dy.parameter(self.pW_s2h)   # state to hidden\n",
    "        b_s2h = dy.parameter(self.pb_s2h)\n",
    "        W_act = dy.parameter(self.pW_act)   # hidden to action\n",
    "        b_act = dy.parameter(self.pb_act)\n",
    "    \n",
    "        # We will keep track of all the losses we accumulate during parsing.\n",
    "        # If some decision is unambiguous because it's the only thing valid given\n",
    "        # the parser state, we will not model it. We only model what is ambiguous.\n",
    "        losses = []\n",
    "        \n",
    "        # push the tokens onto the buffer (tokens is in reverse order)\n",
    "        for tok in tokens:\n",
    "            tok_embedding = self.WORDS_LOOKUP[tok]\n",
    "            buffer.push(tok_embedding, (tok_embedding, self.vocab.i2w[tok]))\n",
    "\n",
    "        while not (len(stack) == 1 and len(buffer) == 0):\n",
    "            # compute probability of each of the actions and choose an action\n",
    "            # either from the oracle or if there is no oracle, based on the model\n",
    "            valid_actions = _valid_actions(stack, buffer)\n",
    "            log_probs = None\n",
    "            action = valid_actions[0]\n",
    "            if len(valid_actions) > 1:\n",
    "                p_t = dy.concatenate([buffer.embedding(), stack.embedding()])\n",
    "                h = dy.tanh(W_s2h * p_t + b_s2h)\n",
    "                logits = W_act * h + b_act\n",
    "                log_probs = dy.log_softmax(logits, valid_actions)\n",
    "                if oracle_actions is None:\n",
    "                    action = np.argmax(log_probs.npvalue())\n",
    "            if oracle_actions is not None:\n",
    "                action = oracle_actions.pop()\n",
    "            if log_probs is not None:\n",
    "                # append the action-specific loss\n",
    "                losses.append(dy.pick(log_probs, action))\n",
    "\n",
    "            # execute the action to update the parser state\n",
    "            if action == SHIFT:\n",
    "                tok_embedding, token = buffer.pop()\n",
    "                stack.push(tok_embedding, (tok_embedding, token))\n",
    "            else: # one of the REDUCE actions\n",
    "                right = stack.pop() # pop a stack state\n",
    "                left = stack.pop()  # pop another stack state\n",
    "                # figure out which is the head and which is the modifier\n",
    "                head, modifier = (left, right) if action == REDUCE_R else (right, left)\n",
    "        \n",
    "                # compute composed representation\n",
    "                head_rep, head_tok = head\n",
    "                mod_rep, mod_tok = modifier\n",
    "                composed_rep = dy.tanh(W_comp * dy.concatenate([head_rep, mod_rep]) + b_comp)\n",
    "                \n",
    "                stack.push(composed_rep, (composed_rep, head_tok))\n",
    "                if oracle_actions is None:\n",
    "                    print('{0} --> {1}'.format(head_tok, mod_tok))\n",
    "\n",
    "        # the head of the tree that remains at the top of the stack is the root\n",
    "        if oracle_actions is None:\n",
    "            head = stack.pop()[1]\n",
    "            print('ROOT --> {0}'.format(head))\n",
    "        return -dy.esum(losses) if losses else None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# load training and dev data\n",
    "vocab_words = Vocab.from_file('data/vocab.txt')\n",
    "train = list(read_oracle('data/small-train.unk.txt', vocab_words, vocab_acts))\n",
    "dev = list(read_oracle('data/small-dev.unk.txt', vocab_words, vocab_acts))\n",
    "\n",
    "model = dy.Model()\n",
    "trainer = dy.AdamTrainer(model)\n",
    "\n",
    "tp = TransitionParser(model, vocab_words)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[validation] epoch 0.0: per-word loss: 2.35509942849\n",
      "epoch 0.01: per-word loss: 1.48255809603\n",
      "epoch 0.02: per-word loss: 1.08090971986\n",
      "epoch 0.03: per-word loss: 0.944152983054\n",
      "epoch 0.04: per-word loss: 0.953884831015\n",
      "epoch 0.05: per-word loss: 0.834949359487\n",
      "epoch 0.06: per-word loss: 0.811419209329\n",
      "epoch 0.07: per-word loss: 0.739562879959\n",
      "epoch 0.08: per-word loss: 0.72608875857\n",
      "epoch 0.09: per-word loss: 0.715399681898\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-9-5eacd1fbaf60>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     15\u001b[0m                 \u001b[0mdev_words\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m                     \u001b[0mdev_loss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscalar_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     18\u001b[0m             \u001b[0;32mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'[validation] epoch {}: per-word loss: {}'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdev_loss\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mdev_words\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m             \u001b[0mvalidation_losses\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdev_loss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "instances_processed = 0\n",
    "validation_losses = []\n",
    "for epoch in range(5):\n",
    "    random.shuffle(train)\n",
    "    words = 0\n",
    "    total_loss = 0.0\n",
    "    for (s,a) in train:\n",
    "        # periodically report validation loss\n",
    "        e = instances_processed / len(train)\n",
    "        if instances_processed % 1000 == 0:\n",
    "            dev_words = 0\n",
    "            dev_loss = 0.0\n",
    "            for (ds, da) in dev:\n",
    "                loss = tp.parse(ds, da)\n",
    "                dev_words += len(ds)\n",
    "                if loss is not None:\n",
    "                    dev_loss += loss.scalar_value()\n",
    "            print('[validation] epoch {}: per-word loss: {}'.format(e, dev_loss / dev_words))\n",
    "            validation_losses.append(dev_loss)\n",
    "\n",
    "        # report training loss\n",
    "        if instances_processed % 100 == 0 and words > 0:\n",
    "            print('epoch {}: per-word loss: {}'.format(e, total_loss / words))\n",
    "            words = 0\n",
    "            total_loss = 0.0\n",
    "    \n",
    "        # here we do training\n",
    "        loss = tp.parse(s, a) # returns None for 1-word sentencs (it's clear how to parse them)\n",
    "        words += len(s)\n",
    "        instances_processed += 1\n",
    "        if loss is not None:\n",
    "            total_loss += loss.scalar_value()\n",
    "            loss.backward()\n",
    "            trainer.update()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fun --> is\n",
      "fun --> Austin\n",
      "in --> fun\n",
      "<unk> --> in\n",
      "<unk> --> .\n",
      "ROOT --> <unk>\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "exprssion 316/1496"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s = 'Parsing in Austin is fun .'\n",
    "UNK = vocab_words.w2i['<unk>']\n",
    "toks = [vocab_words.w2i[x] if x in vocab_words.w2i else UNK for x in s.split()]\n",
    "toks.reverse()\n",
    "tp.parse(toks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
